Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph RNNT: Grid- and Compose-Transducer. W-Transducer loss #6168

Merged
merged 53 commits into from
May 26, 2023

Conversation

artbataev
Copy link
Collaborator

@artbataev artbataev commented Mar 10, 2023

What does this PR do?

Implement Grid- and Compose-Transducer along with W-Transducer loss according to the paper "Powerful and Extensible WFST Framework for Rnn-Transducer Losses" https://ieeexplore.ieee.org/document/10096679 (https://arxiv.org/abs/2303.10384)

Collection: [ASR]

Changelog

  • GraphTransducerLossBase abstract class with the interface for Graph-based loses
  • RNN-T implementation in GraphRnntLoss with tests
  • W-Transducer implementation in GraphWTransducerLoss with tests
  • add GraphRnntLoss + GraphWTransducerLoss to RNN-T loss resolver

Usage

# training some model with Graph-RNNT loss
python examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py \
    <other params> \
    model.loss.loss_name=graph_rnnt \
    ++model.loss.graph_rnnt_kwargs.double_scores=true

# training some model with W-Transducer loss
python examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py \
    <other params> \
    model.loss.loss_name=graph_w_transducer \
    ++model.loss.graph_w_transducer_kwargs.eps_weight=0

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@github-actions github-actions bot added the ASR label Mar 10, 2023
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@github-actions
Copy link
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label Apr 11, 2023
@GNroy GNroy removed the stale label Apr 11, 2023
@github-actions
Copy link
Contributor

github-actions bot commented May 6, 2023

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label May 6, 2023
@artbataev artbataev removed the stale label May 7, 2023
@github-actions github-actions bot added the core Changes to NeMo Core label May 19, 2023
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev requested a review from GNroy May 24, 2023 14:12
@artbataev artbataev marked this pull request as ready for review May 24, 2023 14:12
GNroy
GNroy previously approved these changes May 26, 2023
Copy link
Collaborator

@GNroy GNroy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, please see comments.

nemo/collections/asr/parts/k2/w_transducer.py Outdated Show resolved Hide resolved
nemo/collections/asr/parts/k2/w_transducer.py Show resolved Hide resolved
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
GNroy
GNroy previously approved these changes May 26, 2023
@artbataev artbataev requested a review from titu1994 May 26, 2023 15:57
titu1994
titu1994 previously approved these changes May 26, 2023
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Minor comment about the new arg, make it explicit rather than a default. Easy to miss otherwise

@@ -109,6 +119,20 @@ class RNNTLossConfig:
is_available=True,
installation_msg="Pure Pytorch implementation of Multiblank RNN-T loss. Slow and for debugging purposes only.",
),
"graph_w_transducer": RNNTLossConfig(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the rest of the configs to explicitly show force fp32 is true

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed (I updated all the configs)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev dismissed stale reviews from titu1994 and GNroy via 607601e May 26, 2023 18:56
@artbataev artbataev merged commit f1b5eae into main May 26, 2023
@artbataev artbataev deleted the transucer_compose_grid_wildcard branch May 26, 2023 20:55
@artbataev artbataev mentioned this pull request May 26, 2023
8 tasks
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023
)

* add GraphTransducerLossBase abstract class with the interface for Graph-based loses
* add RNN-T implementation in GraphRnntLoss with tests
* add W-Transducer implementation in GraphWTransducerLoss with tests
* add GraphRnntLoss + GraphWTransducerLoss to RNN-T loss resolver

---------

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ASR core Changes to NeMo Core
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants